# TODO: add more tests

import torch
from models import *
from mappings import cnn2fc, cnn2lc

def test_prime(device=None):
    """Print the output gap between ``skinny`` and ``skinnyprime``.

    Both models are constructed right after ``torch.manual_seed(0)`` and then
    evaluated in eval mode on one random batch of 5 inputs of shape
    (3, 32, 32). The L2 norm of the difference of their outputs is printed.

    Args:
        device: optional torch device to run on. ``None`` (the default) skips
            any ``.to(...)`` call, so tensors and models stay where they were
            created, exactly as before this parameter existed.
    """
    x = torch.randn(5, 3, 32, 32)  # drawn before seeding: input varies per run
    torch.manual_seed(0)
    net = skinny(num_classes=10)
    # Re-seed so both constructors start from the same RNG state.
    torch.manual_seed(0)
    netprime = skinnyprime(num_classes=10)
    if device is not None:
        x, net, netprime = x.to(device), net.to(device), netprime.to(device)
    net.eval()
    netprime.eval()
    # Pure forward comparison: no autograd graph is needed.
    with torch.no_grad():
        print((net(x) - netprime(x)).norm().item())


def test_fc(device):
    """Print how far the ``cnn2fc`` conversion deviates from its source CNN.

    A ``skinnyprime`` CNN is built right after ``torch.manual_seed(0)`` and
    converted with ``cnn2fc``. Both networks are put in eval mode and run on
    the same random batch of 100 inputs of shape (3, 32, 32). The L2 norm of
    the output difference is printed; a value near zero means the two
    forward passes agree on this batch.

    Args:
        device: torch device that the input and both networks are moved to.
    """
    x = torch.randn(100, 3, 32, 32).to(device)  # unseeded: input varies per run
    torch.manual_seed(0)  # makes only the model weights reproducible
    net_ = skinnyprime(num_classes=10).to(device)
    net = cnn2fc(net_).to(device)
    net_.eval()
    net.eval()
    # Pure forward comparison: skip building the autograd graph.
    with torch.no_grad():
        print((net(x) - net_(x)).norm().item())

def test_lc(device):
    """Print how far the ``cnn2lc`` conversion deviates from its source CNN.

    A ``skinnyprime`` CNN is built right after ``torch.manual_seed(0)`` and
    converted with ``cnn2lc``. Both networks are put in eval mode and run on
    the same random batch of 100 inputs of shape (3, 32, 32). The L2 norm of
    the output difference is printed; a value near zero means the two
    forward passes agree on this batch.

    Args:
        device: torch device that the input and both networks are moved to.
    """
    x = torch.randn(100, 3, 32, 32).to(device)  # unseeded: input varies per run
    torch.manual_seed(0)  # makes only the model weights reproducible
    net_ = skinnyprime(num_classes=10).to(device)
    net = cnn2lc(net_).to(device)
    net_.eval()
    net.eval()
    # Pure forward comparison: skip building the autograd graph.
    with torch.no_grad():
        print((net(x) - net_(x)).norm().item())

def test_train(loader=None, device=None):
    """Train a CNN and its ``cnn2fc`` conversion side by side, printing progress.

    Builds ``skinnyprime`` (seed 0) and its fully-connected conversion
    ``cnn2fc(net_)``, and gives each its own SGD optimizer (lr=0.1). For every
    ``(x, y)`` batch it prints the norm of the output difference between the
    two networks. It then takes one cross-entropy SGD step on each network
    with that batch and prints both losses (converted net first, original
    CNN second).

    Args:
        loader: iterable of ``(x, y)`` batches. When omitted, falls back to a
            module-level ``train_loader`` global (the historical behaviour).
        device: torch device to run on. Defaults to CUDA when available,
            otherwise CPU.

    Raises:
        NameError: if no ``loader`` is passed and no ``train_loader`` global
            exists. This is checked before any model is built.
    """
    import torch.optim as optim

    # Resolve the data source first so a missing loader fails fast.
    if loader is None:
        try:
            loader = train_loader  # legacy module-level global
        except NameError:
            raise NameError(
                "test_train needs a data loader: pass loader=... or define "
                "a module-level train_loader"
            ) from None

    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    torch.manual_seed(0)
    net_ = skinnyprime(num_classes=10).to(device)  # original CNN
    net_.eval()
    opt_ = optim.SGD(net_.parameters(), lr=0.1)

    torch.manual_seed(0)
    net = cnn2fc(net_).to(device)  # fully-connected conversion of net_
    net.eval()
    opt = optim.SGD(net.parameters(), lr=0.1)
    crit = torch.nn.CrossEntropyLoss().to(device)

    # TODO: also train a randomly initialized FC net for comparison.
    for x, y in loader:
        # NOTE: both nets stay in eval mode, so dropout (if any) is inactive.
        x, y = x.to(device), y.to(device)
        # Output gap between the two nets before this step; diagnostic only,
        # so no autograd graph is built for it.
        with torch.no_grad():
            print('diff', (net(x) - net_(x)).norm().item())
        # Step the converted (fully-connected) network.
        opt.zero_grad()
        loss = crit(net(x), y)
        loss.backward()
        opt.step()
        # Step the original CNN.
        opt_.zero_grad()
        loss_ = crit(net_(x), y)
        loss_.backward()
        opt_.step()
        print(loss.item(), loss_.item())

# def test():
#     x = torch.randn(5, 3, 32, 32)
#     torch.manual_seed(0)
#     net = SkinnyCNNPrime(num_classes=10)
#     net_sequential = list(net.named_children())[0][1]
#     sizes = [net.input_size] + net.sizes
#     size_pairs = [(sizes[i - 1][-1], sizes[i][-1]) for i in range(1, len(sizes))]
#     layers = []
#     for size, module in zip(size_pairs, net_sequential):
#         if module.__class__ == nn.Conv2d:
#             assert module.padding == (0, 0)
#             d_in, d_out = size
#             k, s = module.kernel_size[0], module.stride[0]
#             ch_in, ch_out = module.in_channels, module.out_channels
#             lin_in, lin_out = ch_in*d_in*d_in, ch_out*d_out*d_out
#             with_bias = module.bias is not None
#             lin = nn.Linear(lin_in, lin_out, bias=with_bias)
#             conv_W = list(module.parameters())[0]
#             lin_W = list(lin.parameters())[0]
#             lin_W.data.zero_()
#             for idx in range(ch_out):
#                 for i in range(d_out):
#                     for j in range(d_out):
#                         reverse_map = lin_W.view(ch_out, d_out, d_out, ch_in, d_in, d_in)[idx][i, j, :, :]
#                         reverse_map[:, i*s:i*s+k, j*s:j*s+k].copy_(conv_W[idx])
#             if with_bias:
#                 conv_B = list(module.parameters())[1]
#                 lin_B = list(lin.parameters())[1]
#                 lin_B.data.copy_(conv_B.expand(d_out * d_out, ch_out).t().contiguous().view(lin_out))
#             layers.append(Reshape([lin_in]))
#             layers.append(lin)
#             layers.append(Reshape([ch_out, d_out, d_out]))
#         elif module.__class__ == nn.ReLU or nn.Linear or nn.Dropout or Reshape or nn.MaxPool2d or nn.ZeroPad2d:
#             layers.append(copy.deepcopy(module))
#     new_net = nn.Sequential(*layers) 
#     new_net.eval()
#     net.eval()
#     print((net(x) - new_net(x)).norm().item())


if __name__ == '__main__':

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Default to float64. This is presumably meant to stop round-off from
    # hiding conversion mismatches. set_default_dtype is the non-deprecated
    # equivalent of set_default_tensor_type('torch.DoubleTensor').
    torch.set_default_dtype(torch.float64)

    test_fc(device)
    test_lc(device)


